{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bk-root\">\n",
       "        <a href=\"https://bokeh.pydata.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
       "        <span id=\"39688ee8-d9fd-46bd-b2c2-7a3ab84d6e58\">Loading BokehJS ...</span>\n",
       "    </div>"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {},
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import bokeh.plotting as bk\n",
    "from bokeh.io import output_notebook\n",
    "output_notebook()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')  # make modules in the parent directory importable\n",
    "import numpy as np\n",
    "# Print floats with three decimals; the ' ' flag reserves a column for the sign.\n",
    "np.set_printoptions(precision=3, formatter={'float': '{: 0.3f}'.format})\n",
    "# Report all floating-point error categories as warnings.\n",
    "np.seterr('warn')\n",
    "import os\n",
    "import cProfile, pstats, io"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bokeh.palettes import Category20\n",
    "# 20-entry categorical palette; the plotting cells pick line colours from it by index.\n",
    "colors = Category20[20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both arrays are read from ./data relative to the current working directory.\n",
    "path = os.path.join(os.getcwd(), 'data')\n",
    "# Inclusive window on the 'ts' field; frame_end == -1 disables the upper bound.\n",
    "frame_start, frame_end = -1, 100\n",
    "measurements = np.load(os.path.join(path, 'simulated_data' + '.npy'))\n",
    "gt_bboxes = np.load(os.path.join(path, 'gt_bboxes' + '.npy'))\n",
    "# Compare by value: `is -1` tests object identity, which is implementation-dependent\n",
    "# for ints and raises a SyntaxWarning on Python >= 3.8.\n",
    "if frame_end == -1:\n",
    "    measurements = measurements[measurements['ts'] >= frame_start]\n",
    "    gt_bboxes = gt_bboxes[gt_bboxes['ts'] >= frame_start]\n",
    "else:\n",
    "    measurements = measurements[ (measurements['ts'] >= frame_start) & (measurements['ts'] <= frame_end)]\n",
    "    gt_bboxes = gt_bboxes[ (gt_bboxes['ts'] >= frame_start) & (gt_bboxes['ts'] <= frame_end)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spline Tracker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Tracker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One tracker step per integer timestamp in 0 .. max(ts).\n",
    "steps = measurements['ts'].max() + 1\n",
    "\n",
    "# --- tracker config ---\n",
    "dt = 0.04\n",
    "sa_sq = 1.5 ** 2\n",
    "somega_sq = 0.05 ** 2\n",
    "\n",
    "# Process noise over the 7-dimensional state. The plotting cells read state\n",
    "# indices 0..3 as x, y, phi, v and the last two as the extent scaling;\n",
    "# index 4 is presumably the turn rate -- TODO confirm against the tracker model.\n",
    "state_dim = 7\n",
    "q = np.zeros((state_dim, state_dim), dtype='f4')\n",
    "q[0, 0] = q[1, 1] = 1e-3 * dt\n",
    "q[2, 2] = somega_sq * dt ** 3 / 3.0\n",
    "q[2, 4] = q[4, 2] = somega_sq * dt ** 2 / 2.0\n",
    "q[3, 3] = sa_sq * dt\n",
    "q[4, 4] = somega_sq * dt\n",
    "q[5, 5] = q[6, 6] = 1e-5 * dt\n",
    "\n",
    "config = dict(\n",
    "    steps=steps + 1,\n",
    "    d=2,\n",
    "    sd=state_dim,\n",
    "    alpha=1.0,\n",
    "    beta=2.0,\n",
    "    kappa=2.0,\n",
    "    sa_sq=sa_sq,\n",
    "    somega_sq=somega_sq,\n",
    "    q=q,\n",
    "    r=0.05 ** 2 * np.identity(2),\n",
    "    init_m=np.asarray([6.5, 2.5, 0.00, 12, 0, 1, 1]),\n",
    "    # Eight control points of the closed outline, listed around a 5 x 2 rectangle.\n",
    "    p_basis=np.array([\n",
    "        [2.5, 0.0],\n",
    "        [2.5, 1.0],\n",
    "        [0.0, 1.0],\n",
    "        [-2.5, 1.0],\n",
    "        [-2.5, 0.0],\n",
    "        [-2.5, -1.0],\n",
    "        [0.0, -1.0],\n",
    "        [2.5, -1.0],\n",
    "    ]),\n",
    "    init_c=4 * q + np.diag([1.25, 1.25, 0.05, 0.0, 0.0, 0.0, 0.0]),\n",
    "    scale_correction=True,\n",
    "    orientation_correction=True,\n",
    ")\n",
    "\n",
    "from models import UkfSplineTracker as Tracker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sys.path\n",
    "# __file__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Tracker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total time: 3.462291, total steps: 97, average step time: 0.035694\n",
      "         70100 function calls in 3.462 seconds\n",
      "\n",
      "   Ordered by: cumulative time\n",
      "   List reduced from 92 to 10 due to restriction <10>\n",
      "\n",
      "   ncalls  tottime  percall  cumtime  percall filename:lineno(function)\n",
      "      100    0.007    0.000    3.462    0.035 ../tracker.py:47(step)\n",
      "      100    0.611    0.006    3.312    0.033 ../models/spline.py:738(correct)\n",
      "      100    1.282    0.013    1.290    0.013 /Users/Jens/anaconda3/lib/python3.6/site-packages/numpy/linalg/linalg.py:464(inv)\n",
      "      500    0.751    0.002    1.045    0.002 /Users/Jens/anaconda3/lib/python3.6/site-packages/numpy/lib/function_base.py:1044(average)\n",
      "     2200    0.264    0.000    0.264    0.000 {method 'reduce' of 'numpy.ufunc' objects}\n",
      "     1000    0.004    0.000    0.246    0.000 {method 'sum' of 'numpy.ndarray' objects}\n",
      "     1100    0.001    0.000    0.244    0.000 /Users/Jens/anaconda3/lib/python3.6/site-packages/numpy/core/_methods.py:31(_sum)\n",
      "      700    0.188    0.000    0.188    0.000 {built-in method numpy.core.multiarray.dot}\n",
      "      300    0.018    0.000    0.171    0.001 /Users/Jens/anaconda3/lib/python3.6/site-packages/numpy/core/einsumfunc.py:819(einsum)\n",
      "      100    0.020    0.000    0.143    0.001 ../models/spline.py:709(predict)\n",
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# tracker definition\n",
    "tracker = Tracker(dt=dt, **config)\n",
    "\n",
    "# Profile only tracker.step(); selecting the scan for each timestamp is excluded.\n",
    "pr = cProfile.Profile()\n",
    "for i in range(steps):\n",
    "    scan = measurements[measurements['ts'] == i]\n",
    "    pr.enable()\n",
    "    tracker.step(scan)\n",
    "    pr.disable()\n",
    "\n",
    "estimates, log_lik = tracker.extract()\n",
    "bboxes = tracker.extrackt_bbox()\n",
    "\n",
    "# Render the ten most expensive entries (by cumulative time) into a string buffer.\n",
    "s = io.StringIO()\n",
    "ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')\n",
    "ps.print_stats(10)\n",
    "\n",
    "# The loop above calls tracker.step() exactly `steps` times, so that is the\n",
    "# divisor for the average; `max(ts) - 2` under-counted the profiled calls.\n",
    "print('total time: {:f}, total steps: {:d}, average step time: {:f}'.format(\n",
    "    ps.total_tt, steps, ps.total_tt / steps))\n",
    "print(s.getvalue())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate Wasserstein Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from metric import point_set_wasserstein_distance as pt_wsd\n",
    "from misc import convert_rectangle_to_eight_point\n",
    "\n",
    "# Convert estimated and ground-truth rectangles to point sets for the metric.\n",
    "eight_pts = convert_rectangle_to_eight_point(bboxes[1:])  # drop prior bounding box\n",
    "eight_pts_gt = convert_rectangle_to_eight_point(gt_bboxes)\n",
    "\n",
    "# One distance per ground-truth box; zip() stops at the shorter sequence, so\n",
    "# entries without a matching estimate keep their initial value of 0.\n",
    "wsd = np.zeros(len(gt_bboxes), dtype='f8')\n",
    "matched = zip(eight_pts, eight_pts_gt)\n",
    "for frame_idx, (est_pts, ref_pts) in enumerate(matched):\n",
    "    wsd[frame_idx] = pt_wsd(est_pts, ref_pts, p=2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Data (ColumnDataSource Version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bokeh.models import ColumnDataSource\n",
    "\n",
    "stride = 5  # the *_sel sources keep only timestamps divisible by this value\n",
    "\n",
    "\n",
    "def _xy_columns(meas, rows=slice(None)):\n",
    "    \"\"\"Return the ts / x / y columns of a measurement array, restricted to `rows`.\"\"\"\n",
    "    return {\n",
    "        'ts': meas['ts'][rows],\n",
    "        'x': meas['xy'][rows, 0],\n",
    "        'y': meas['xy'][rows, 1],\n",
    "    }\n",
    "\n",
    "\n",
    "def _bbox_columns(boxes, rows=slice(None)):\n",
    "    \"\"\"Return centre, orientation and size columns of a box array, restricted to `rows`.\"\"\"\n",
    "    return {\n",
    "        'ts': boxes['ts'][rows],\n",
    "        'm_x': boxes['center_xy'][rows, 0],\n",
    "        'm_y': boxes['center_xy'][rows, 1],\n",
    "        'phi': boxes['orientation'][rows],\n",
    "        'l': boxes['dimension'][rows, 0],\n",
    "        'w': boxes['dimension'][rows, 1],\n",
    "    }\n",
    "\n",
    "\n",
    "# Full state track plus the per-step log likelihood.\n",
    "state = estimates['m']\n",
    "est_source = ColumnDataSource(data=dict(\n",
    "    ts=estimates['ts'],\n",
    "    m_x=state[:, 0],\n",
    "    m_y=state[:, 1],\n",
    "    phi=state[:, 2],\n",
    "    v=state[:, 3],\n",
    "    d_x=state[:, -2],\n",
    "    d_y=state[:, -1],\n",
    "    loglik=log_lik,\n",
    "))\n",
    "\n",
    "# All measurements, and the strided subset.\n",
    "meas_source = ColumnDataSource(data=_xy_columns(measurements))\n",
    "\n",
    "sel = measurements['ts'] % stride == 0\n",
    "meas_source_sel = ColumnDataSource(data=_xy_columns(measurements, sel))\n",
    "\n",
    "# Estimated boxes (strided only).\n",
    "sel = bboxes['ts'] % stride == 0\n",
    "bbox_source = ColumnDataSource(data=_bbox_columns(bboxes, sel))\n",
    "\n",
    "# Ground-truth boxes: the full set carries the Wasserstein distance as an extra column.\n",
    "gt_columns = _bbox_columns(gt_bboxes)\n",
    "gt_columns['wsd'] = wsd\n",
    "gt_bbox_source = ColumnDataSource(data=gt_columns)\n",
    "\n",
    "sel = gt_bboxes['ts'] % stride == 0\n",
    "gt_bbox_source_sel = ColumnDataSource(data=_bbox_columns(gt_bboxes, sel))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def periodic_spline_pts_from_knots(p_base, num_pts, m_base, tau_vec):\n",
    "    \"\"\"\n",
    "    Evaluate a closed (periodic) spline given by its control points.\n",
    "\n",
    "    Segment n is built from the three consecutive control points n, n + 1 and\n",
    "    n + 2 (indices wrap around), so there are len(p_base) segments.\n",
    "\n",
    "    :param p_base: control points indexed as p_base[point, coordinate];\n",
    "        assumes two coordinates per point -- TODO confirm\n",
    "    :param num_pts: unused; the sample count per segment is tau_vec.shape[-1]\n",
    "    :param m_base: 3 x 3 basis matrix applied to each triple of control points\n",
    "    :param tau_vec: 3 x num_samples matrix of curve-parameter terms\n",
    "    :return: array reshaped to (2, -1), i.e. 2 x (len(p_base) * num_samples)\n",
    "        for 2-D points\n",
    "    \"\"\"\n",
    "    n_ctrl = len(p_base)\n",
    "    # window[n] = (n, n + 1, n + 2) modulo n_ctrl\n",
    "    window = (np.arange(n_ctrl)[:, None] + np.arange(3)) % n_ctrl\n",
    "    ctrl = p_base[window]\n",
    "\n",
    "    coeff = np.dot(m_base, ctrl)\n",
    "    return np.dot(coeff.T, tau_vec).reshape(2, -1)\n",
    "\n",
    "from numba import njit, float64, int64, float32, int32\n",
    "\n",
    "@njit\n",
    "def jit_periodic_spline_pts_from_knots(p_base, num_pts, m_base, tau_vec):\n",
    "    \"\"\"\n",
    "    Loop-based (numba) evaluation of a closed spline from its control points.\n",
    "\n",
    "    Segment n combines the control points n, n + 1 and n + 2 (indices wrap\n",
    "    around) through the 3 x 3 matrix m_base and the columns of tau_vec.\n",
    "    num_pts is unused; the sample count per segment is tau_vec.shape[-1].\n",
    "\n",
    "    Returns an array of shape (2, len(p_base), tau_vec.shape[-1]) with the\n",
    "    dtype of p_base; the first axis selects the coordinate (x, y).\n",
    "    \"\"\"\n",
    "    n_ctrl = len(p_base)\n",
    "    n_samples = tau_vec.shape[-1]\n",
    "    out = np.zeros((2, n_ctrl, n_samples), dtype=p_base.dtype)\n",
    "\n",
    "    for axis in range(2):  # 0 -> x, 1 -> y\n",
    "        for seg in range(n_ctrl):\n",
    "            for s in range(n_samples):\n",
    "                for row in range(3):\n",
    "                    for col in range(3):\n",
    "                        # control point `col` of this segment, wrapping past the end\n",
    "                        ctrl = p_base[(seg + col) % n_ctrl, axis]\n",
    "                        out[axis, seg, s] += tau_vec[row, s] * m_base[row, col] * ctrl\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<div class=\"bk-root\">\n",
       "    <div class=\"bk-plotdiv\" id=\"e4b4059b-e9dd-473c-8344-2987e73ebd91\"></div>\n",
       "</div>"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {},
     "metadata": {
      "application/vnd.bokehjs_exec.v0+json": {
       "id": "d6bff355-d055-41e9-a174-88b99f64ccba"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from bokeh.layouts import gridplot\n",
    "from bokeh.plotting import figure, output_file, show\n",
    "from bokeh.models import HoverTool, BoxSelectTool, PanTool, BoxZoomTool, WheelZoomTool, UndoTool, RedoTool, ResetTool, SaveTool\n",
    "\n",
    "TOOLS = [PanTool(), BoxSelectTool(), BoxZoomTool(), WheelZoomTool(), UndoTool(), RedoTool(), ResetTool(), SaveTool()]\n",
    "\n",
    "# Top panel: all measurements and the estimated track, with a hover read-out of the state.\n",
    "top = figure(tools=TOOLS, width=800, height=350, title=None, match_aspect=True, output_backend=\"webgl\")\n",
    "meas_scatter = top.circle('x', 'y', color='#303030', legend='measurements', size=1, alpha=0.2, source=meas_source)\n",
    "track_plot = top.line('m_x', 'm_y', legend='track', source=est_source)\n",
    "hover = HoverTool(renderers=[track_plot],\n",
    "        tooltips=[\n",
    "            (\"index\", \"$index\"),\n",
    "            (r\"(x, y, phi, v)\", \"(@m_x, @m_y, @phi, @v)\"),\n",
    "            (\"(length, width)\", \"(@d_x, @d_y)\"),\n",
    "        ]\n",
    "    )\n",
    "top.add_tools(hover)\n",
    "\n",
    "# Bottom panel: every stride-th frame with estimated boxes, ground-truth boxes and\n",
    "# the spline outline; its ranges are linked to the top panel.\n",
    "bottom = figure(tools=TOOLS, width=800, height=350, title=None,\n",
    "                x_range=top.x_range, y_range=top.y_range, output_backend=\"webgl\")\n",
    "bottom.circle('x', 'y', color='#303030', legend='measurements', size=1, alpha=0.2, source=meas_source_sel)\n",
    "bottom.rect(x='m_x', y='m_y', width='l', height='w', angle='phi', color=\"#CAB2D6\",\n",
    "            fill_alpha=0.2, source=bbox_source, legend='bounding boxes')\n",
    "bottom.rect(x='m_x', y='m_y', width='l', height='w', angle='phi', color='#98df8a',\n",
    "            fill_alpha=0.2, source=gt_bbox_source_sel, legend='ground truth boxes')\n",
    "\n",
    "sel = estimates['ts'] % stride == 0\n",
    "m_base = np.array([[0.5, -1.0, 0.5],\n",
    "                   [-1.0, 1.0, 0.0],\n",
    "                   [0.5, 0.5, 0.0]], dtype='f4')\n",
    "num_pts = 10  # samples per spline segment\n",
    "tau = np.linspace(0, 1, num_pts, endpoint=False)\n",
    "tau_vec = np.vstack((tau ** 2, tau, np.ones(num_pts)))\n",
    "for est in estimates[sel]:\n",
    "    s_phi, c_phi = np.sin(est['m'][2]), np.cos(est['m'][2])\n",
    "    rot = np.asarray([[c_phi, -s_phi],\n",
    "                      [s_phi, c_phi]])\n",
    "\n",
    "    # Outline points flattened to 2 x N, rotated by phi and shifted to the estimated\n",
    "    # position. Stored under its own name so the process-noise matrix `q` from the\n",
    "    # tracker config cell is not overwritten.\n",
    "    outline = jit_periodic_spline_pts_from_knots(est['p_basis'], num_pts, m_base, tau_vec)\n",
    "    outline = outline.reshape(2, -1)\n",
    "    outline = np.dot(rot, outline) + est['m'][:2, None]\n",
    "    bottom.line(x=outline[0], y=outline[1], color=colors[0], legend='shape')\n",
    "\n",
    "bottom.legend.click_policy=\"hide\"\n",
    "\n",
    "p = gridplot([[top], [bottom]])\n",
    "show(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<div class=\"bk-root\">\n",
       "    <div class=\"bk-plotdiv\" id=\"f283fea6-eaef-4f57-afb3-349db74ac08a\"></div>\n",
       "</div>"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {},
     "metadata": {
      "application/vnd.bokehjs_exec.v0+json": {
       "id": "2337c1bc-7bb4-4c5b-b1ed-6c942ad91262"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from bokeh.models.widgets import Panel, Tabs\n",
    "from bokeh.layouts import gridplot, column\n",
    "\n",
    "# Tab 1: Wasserstein distance and log likelihood over time, sharing one x range.\n",
    "f_wsd = figure(tools=TOOLS, width=800, height=350, title='Wasserstein Distance')\n",
    "f_log_lik = figure(tools=TOOLS, width=800, height=350, title='Likelihood', x_range=f_wsd.x_range)\n",
    "f_wsd.line('ts', 'wsd', legend='wasserstein distance', source=gt_bbox_source, color=colors[0])\n",
    "f_log_lik.line('ts', 'loglik', legend='log likelihood', source=est_source, color=colors[0])\n",
    "pan1 = Panel(child=column([f_wsd, f_log_lik]), title=\"Likelihood and Wasserstein Distance\")\n",
    "\n",
    "# Tab 2: estimated state against ground truth; all four figures share the x range of f_position.\n",
    "f_position = figure(tools=TOOLS, width=800, height=350, title='Position')\n",
    "f_orientation = figure(tools=TOOLS, width=800, height=350, title='Orientation', x_range=f_position.x_range)\n",
    "f_velocity = figure(tools=TOOLS, width=800, height=350, title='Velocity', x_range=f_position.x_range)\n",
    "f_dimension = figure(tools=TOOLS, width=800, height=350, title='Dimension', x_range=f_position.x_range)\n",
    "f_position.line('ts', 'm_x', legend='x position', source=est_source, color=colors[0])\n",
    "f_position.line('ts', 'm_y', legend='y position', source=est_source, color=colors[2])\n",
    "f_position.line('ts', 'm_x', legend='ground truth x position', source=gt_bbox_source, color=colors[1])\n",
    "f_position.line('ts', 'm_y', legend='ground truth y position', source=gt_bbox_source, color=colors[3])\n",
    "f_orientation.line('ts', 'phi', legend='orientation', source=est_source, color=colors[0])\n",
    "f_orientation.line('ts', 'phi', legend='ground truth orientation', source=gt_bbox_source, color=colors[1])\n",
    "f_velocity.line('ts', 'v', legend='velocity', source=est_source, color=colors[0])\n",
    "# f_velocity.line('ts', 'v', legend='ground truth orientation', source=gt_bbox_source, color=colors[3])\n",
    "# Estimated dimensions come from bbox_source, which holds only every stride-th frame.\n",
    "f_dimension.line('ts', 'l', legend='x dimension scaling', source=bbox_source, color=colors[0])\n",
    "f_dimension.line('ts', 'w', legend='y dimension scaling', source=bbox_source, color=colors[2])\n",
    "f_dimension.line('ts', 'l', legend='ground truth x dimension', source=gt_bbox_source, color=colors[1])\n",
    "f_dimension.line('ts', 'w', legend='ground truth y dimension', source=gt_bbox_source, color=colors[3])\n",
    "pan2 = Panel(child=column([f_position, f_orientation, f_velocity, f_dimension]), title=\"Kinematic and Extent\")\n",
    "\n",
    "# Clicking a legend entry mutes the corresponding line instead of hiding it.\n",
    "for f in [f_wsd, f_log_lik, f_position, f_orientation, f_velocity, f_dimension]:\n",
    "    f.legend.click_policy=\"mute\"\n",
    "\n",
    "tabs = Tabs(tabs=[pan1, pan2])\n",
    "show(tabs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
